{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-24T11:42:01.682182Z",
     "start_time": "2020-04-24T11:42:01.236039Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch.scatter_散射(torch.scatter)\n",
    "\n",
    "scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中。\n",
    "\n",
    "\n",
    "```py\n",
    "y[index[i][j][k]][j][k] = x[i][j][k]  # if dim == 0\n",
    "y[i][index[i][j][k]][k] = x[i][j][k]  # if dim == 1\n",
    "y[i][j][index[i][j][k]] = x[i][j][k]  # if dim == 2\n",
    "```\n",
    "\n",
    "\n",
    "当dim=0时, index的值是y的dim=0的坐标"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-24T11:42:42.256326Z",
     "start_time": "2020-04-24T11:42:42.226659Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0.]]),\n",
       " tensor([[  1.,   2.,   3.,   4.,   5.],\n",
       "         [  6.,   7.,   8.,   9., 100.],\n",
       "         [ 11.,  12.,  13.,  14.,  15.]]),\n",
       " tensor([[1, 2, 2, 0, 0],\n",
       "         [2, 0, 0, 1, 2]]),\n",
       " tensor([[  0.,   7.,   8.,   4.,   5.],\n",
       "         [  1.,   0.,   0.,   9.,   0.],\n",
       "         [  6.,   2.,   3.,   0., 100.]]),\n",
       " tensor(1.),\n",
       " tensor(1.),\n",
       " tensor(7.),\n",
       " tensor(7.))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([[1,2,3,4,5],[6,7,8,9,100],[11,12,13,14,15]], dtype=torch.float)\n",
    "index = torch.tensor([[1,2,2,0,0],[2,0,0,1,2]])\n",
    "y = torch.zeros(3, 5)\n",
    "\n",
    "z = torch.scatter(y, dim=0, index=index, src=x)\n",
    "y, x, index, z, z[index[0][0]][0], x[0][0], z[index[1][1]][1], x[1][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-24T11:42:44.572977Z",
     "start_time": "2020-04-24T11:42:44.533507Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3])\n",
      "torch.Size([3])\n",
      "3 3\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0.]]),\n",
       " tensor([[  1.,   2.,   3.,   4.,   5.],\n",
       "         [  6.,   7.,   8.,   9., 100.],\n",
       "         [ 11.,  12.,  13.,  14.,  15.]]),\n",
       " tensor([[0, 1, 2, 0, 0],\n",
       "         [2, 0, 0, 1, 2],\n",
       "         [0, 0, 0, 3, 3]]),\n",
       " tensor([[1.2000, 1.2000, 1.2000, 0.0000, 0.0000],\n",
       "         [1.2000, 1.2000, 1.2000, 0.0000, 0.0000],\n",
       "         [1.2000, 0.0000, 0.0000, 1.2000, 0.0000]]),\n",
       " torch.Size([3, 5]),\n",
       " tensor(1.2000))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(x[:, 0, ...].shape)\n",
    "print(y[:, 0, ...].shape) # 3 != 2\n",
    "print(x.size(0), y.size(0))\n",
    "try:\n",
    "    index = torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])\n",
    "    z = torch.scatter(y, dim=1, index=index, value=1.2)\n",
    "except:\n",
    "    index = torch.tensor([[0,1,2,0,0],[2,0,0,1,2],[0,0,0,3,3]])\n",
    "    z = torch.scatter(y, dim=1, index=index, value=1.2)\n",
    "    \n",
    "y, x, index, z, z.shape, z[0][index[0][0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-24T11:42:45.042473Z",
     "start_time": "2020-04-24T11:42:45.022441Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0.]]),\n",
       " tensor([[ 1.,  2.,  3.,  4.,  5.],\n",
       "         [ 6.,  7.,  8.,  9.,  0.],\n",
       "         [11., 12., 13., 14., 15.]]),\n",
       " tensor([[2, 1, 1, 0, 1],\n",
       "         [2, 0, 0, 1, 2],\n",
       "         [0, 0, 0, 3, 3]]),\n",
       " tensor([[ 4.,  5.,  1.,  0.,  0.],\n",
       "         [ 8.,  9.,  0.,  0.,  0.],\n",
       "         [13.,  0.,  0., 15.,  0.]]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([[1,2,3,4,5],[6,7,8,9,0],[11,12,13,14,15]], dtype=torch.float)\n",
    "index = torch.tensor([[2,1,1,0,1],[2,0,0,1,2],[0,0,0,3,3]])\n",
    "\n",
    "y = torch.zeros(3, 5)\n",
    "z = torch.scatter(y, dim=1, index=index, src=x)\n",
    "\n",
    "y, x, index, z"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## one-hot\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-26T08:43:14.479821Z",
     "start_time": "2020-04-26T08:43:14.466156Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0.],\n",
       "         [0., 0., 0., 0., 0.]]),\n",
       " torch.Size([5, 5]),\n",
       " tensor([[0, 1, 2, 3, 4]]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = torch.zeros(5,5,dtype=torch.float32)\n",
    "index = torch.tensor([[0,1,2,3,4]],dtype=torch.int64)\n",
    "y, y.shape, index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-24T11:42:01.821206Z",
     "start_time": "2020-04-24T11:42:01.165Z"
    }
   },
   "outputs": [],
   "source": [
    "z = torch.scatter(y, dim=0, index=index, value=1.0)\n",
    "z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch.gather (收集)\n",
    "\n",
    "根据index和dim将x中的数据挑选出来，放置到y中\n",
    "\n",
    "```py\n",
    "y[i][j][k] = x[index[i][j][k]][j][k]  # if dim == 0\n",
    "y[i][j][k] = x[i][index[i][j][k]][k]  # if dim == 1\n",
    "y[i][j][k] = x[i][j][index[i][j][k]]  # if dim == 2\n",
    "```\n",
    "\n",
    "\n",
    "```py\n",
    "y[i][j] = x[index[i][j]][j]  # if dim == 0\n",
    "y[i][j] = x[i][index[i][j]]  # if dim == 1\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-24T11:42:01.822367Z",
     "start_time": "2020-04-24T11:42:01.167Z"
    }
   },
   "outputs": [],
   "source": [
    "x = torch.tensor([[1,2],[3,4]])\n",
    "index = torch.tensor([[0,0],[1,0]])\n",
    "y = torch.gather(input=x, dim=1, index=index)\n",
    "x, index, y, y[0][1], x[0][index[0][1]]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
